{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea to create this model was sparked by a tweet: https://twitter.com/fchollet/status/711594805692792832\n",
    "\n",
    "I studied the following resources to be able to finish the model:\n",
    "0. http://artint.info/html/ArtInt_265.html\n",
    "1. https://edersantana.github.io/articles/keras_rl/\n",
    "2. http://www.nervanasys.com/demystifying-deep-reinforcement-learning/\n",
    "3. http://keras.io/\n",
    "\n",
    "Here's a gif of the model playing a few games of catching fruit:\n",
    "<img src=\"files/fruit.gif\" />\n",
    "\n",
    "Btw, I really, *really* liked Eder Santana's idea to apply big ideas to toy examples. I hope to do this more often to feed my fingerspitzengefühl."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "from random import sample as rsample\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.layers.convolutional import Convolution2D\n",
    "from keras.layers.core import Dense, Flatten\n",
    "from keras.optimizers import SGD, RMSprop\n",
    "\n",
    "%pylab inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "GRID_SIZE = 10\n",
    "\n",
    "def episode():\n",
    "    \"\"\"\n",
    "    Generator-based coroutine that plays a single game of catch.\n",
    "\n",
    "    Every yield produces a tuple `(state, outcome)`. `state` is the grid with\n",
    "    a leading axis of length one, i.e. shape `(1, GRID_SIZE, GRID_SIZE)`.\n",
    "    `outcome` is 0 while the fruit is still falling, 1 when the fruit is\n",
    "    caught and -1 when it is missed.\n",
    "\n",
    "    The caller has to explicitly `send` the next action, which is added to\n",
    "    the basket's centre column. Resuming the coroutine after a non-zero\n",
    "    outcome was yielded ends it, so the caller sees StopIteration.\n",
    "    \"\"\"\n",
    "    # The fruit column is drawn first and the basket centre second; the fruit\n",
    "    # always starts on the top row.\n",
    "    fruit_col = np.random.randint(0, GRID_SIZE)\n",
    "    fruit_row = 0\n",
    "    basket_mid = np.random.randint(1, GRID_SIZE - 1)\n",
    "\n",
    "    while True:\n",
    "        # Render the current frame from scratch.\n",
    "        grid = np.zeros((GRID_SIZE, GRID_SIZE))\n",
    "        grid[fruit_row, fruit_col] = 1.\n",
    "        basket_cols = range(basket_mid - 1, basket_mid + 2)  # Basket is three cells wide\n",
    "        grid[-1, basket_cols] = 1.\n",
    "\n",
    "        # The game is decided once the fruit reaches the penultimate row.\n",
    "        if fruit_row < GRID_SIZE - 2:\n",
    "            outcome = 0\n",
    "        elif fruit_col in basket_cols:\n",
    "            outcome = 1\n",
    "        else:\n",
    "            outcome = -1\n",
    "\n",
    "        action = yield grid[np.newaxis], outcome\n",
    "        if outcome:\n",
    "            return\n",
    "\n",
    "        # Shift the basket while keeping all three of its cells on the grid.\n",
    "        basket_mid = min(max(basket_mid + action, 1), GRID_SIZE - 2)\n",
    "        fruit_row += 1\n",
    "\n",
    "            \n",
    "def experience_replay(batch_size, max_memory=None):\n",
    "    \"\"\"\n",
    "    Coroutine of experience replay.\n",
    "\n",
    "    Provide a new experience by calling send, which stores it and in turn\n",
    "    yields a random batch of `batch_size` distinct stored experiences (the\n",
    "    one just sent is eligible). While fewer than `batch_size` experiences\n",
    "    are stored, None is yielded instead. The coroutine has to be primed with\n",
    "    `next` before the first send.\n",
    "\n",
    "    batch_size: number of experiences per yielded batch.\n",
    "    max_memory: optional cap on the number of stored experiences. When the\n",
    "        cap is exceeded the oldest experience is dropped. The default, None,\n",
    "        keeps every experience. A cap below `batch_size` means that no batch\n",
    "        is ever yielded.\n",
    "    \"\"\"\n",
    "    memory = []\n",
    "    while True:\n",
    "        experience = yield rsample(memory, batch_size) if batch_size <= len(memory) else None\n",
    "        memory.append(experience)\n",
    "        # Keep the memory bounded so that long runs do not grow it forever.\n",
    "        if max_memory is not None and len(memory) > max_memory:\n",
    "            del memory[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100, loss: 0.026815\n",
      "Epoch 200, loss: 0.023206\n",
      "Epoch 300, loss: 0.022220\n",
      "Epoch 400, loss: 0.013861\n",
      "Epoch 500, loss: 0.010943\n",
      "Epoch 600, loss: 0.008557\n",
      "Epoch 700, loss: 0.010435\n",
      "Epoch 800, loss: 0.003742\n",
      "Epoch 900, loss: 0.003236\n",
      "Epoch 1000, loss: 0.005833\n"
     ]
    }
   ],
   "source": [
    "nb_epochs = 1000\n",
    "batch_size = 128\n",
    "epsilon = .8  # Probability of taking a random (exploratory) action\n",
    "gamma = .8  # Discount factor applied to future q-values\n",
    "\n",
    "# Recipe of deep reinforcement learning model\n",
    "model = Sequential()\n",
    "model.add(Convolution2D(16, nb_row=3, nb_col=3, input_shape=(1, GRID_SIZE, GRID_SIZE), activation='relu'))\n",
    "model.add(Convolution2D(16, nb_row=3, nb_col=3, activation='relu'))\n",
    "model.add(Flatten())\n",
    "model.add(Dense(100, activation='relu'))\n",
    "model.add(Dense(3))  # One q-value per action: -1, 0, 1\n",
    "model.compile(RMSprop(), 'MSE')\n",
    "\n",
    "exp_replay = experience_replay(batch_size)\n",
    "next(exp_replay)  # Start experience-replay coroutine\n",
    "\n",
    "for i in range(nb_epochs):\n",
    "    ep = episode()\n",
    "    S, won = next(ep)  # Start coroutine of single entire episode\n",
    "    loss = 0.\n",
    "    while True:\n",
    "        # Epsilon-greedy policy: start from a random action and replace it\n",
    "        # with the greedy one with probability 1 - epsilon.\n",
    "        action = np.random.randint(-1, 2)\n",
    "        if np.random.random() > epsilon:\n",
    "            # Get the index of the maximum q-value of the model.\n",
    "            # Subtract one because actions are either -1, 0, or 1\n",
    "            action = np.argmax(model.predict(S[np.newaxis]), axis=-1)[0] - 1\n",
    "\n",
    "        # Only the episode coroutine signals the end of the game, so only\n",
    "        # this call is guarded.\n",
    "        try:\n",
    "            S_prime, won = ep.send(action)\n",
    "        except StopIteration:\n",
    "            break  # The previous step was terminal\n",
    "\n",
    "        experience = (S, action, won, S_prime)\n",
    "        S = S_prime\n",
    "\n",
    "        batch = exp_replay.send(experience)\n",
    "        if batch:\n",
    "            # Split the sampled (s, a, r, s') tuples into one array per field.\n",
    "            inputs, actions, rewards, next_inputs = (np.array(column) for column in zip(*batch))\n",
    "\n",
    "            # The targets of unchosen actions are the q-values of the model,\n",
    "            # so that the corresponding errors are 0. The targets of chosen actions\n",
    "            # are either the rewards, in case a terminal state has been reached,\n",
    "            # or future discounted q-values, in case episodes are still running.\n",
    "            # Both predictions run once per batch instead of once per sample.\n",
    "            targets = model.predict(inputs)\n",
    "            future_q = model.predict(next_inputs).max(axis=-1)\n",
    "            chosen = actions + 1  # Map actions -1, 0, 1 to output columns 0, 1, 2\n",
    "            targets[np.arange(len(batch)), chosen] = np.where(rewards != 0, rewards, gamma * future_q)\n",
    "\n",
    "            loss += model.train_on_batch(inputs, targets)\n",
    "\n",
    "    if (i + 1) % 100 == 0:\n",
    "        print('Epoch %i, loss: %.6f' % (i + 1, loss))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPYAAAD7CAYAAABZjGkWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAChhJREFUeJzt3VuInXe5x/Hvz0Tdxnq49BRYRSpUUTC4JVSE7E0vQhG9\nEWxRFC+88lC9EK0XFnspiArihdotFKvCjiIKHlA0BhHE2rRqkmKVPdg0BEWxVa8SfbyYxc7Ymazz\nWu/Mw/cDC9Za82fWQ5jv/N/1zpuZVBWSenna0ANIWj3DlhoybKkhw5YaMmypIcOWGjq87CdI4s/L\npAFVVZ763NJhb7t7jrWngROredmNOM3BmheceRNOM/y8Bdyz50c8FJcaMmypoQHCHm3+JZcyGnqA\nBYyGHmABo6EHmNNo6AEmMuypRkMPsIDR0AMsYDT0AHMaDT3ARFPDTnIyySNJHk3yoU0MJWk5E8NO\ncgj4DHASeDlwR5KbNzGYpMVN27FfC/y2qraq6grwVeBN6x9L0jKmhf1i4LEdjy+On5O0j00L26vK\npANo2pVnjwNHdzw+yvau/RSnd9wfsd/PGEoH19b4BpP23WlhPwDclGQEXALeAtyxe9mJOYeTtJgR\n1zbOAs7suWpi2FV1Ncl7gO8Bh4B7q+rCymaUtBZT/xNIVX0H+M4GZpG0Il4rLjVk2FJDhi01ZNhS\nQ4YtNWTYUkOGLTVk2FJDhi01ZNhSQ4YtNWTYUkOGLTVk2FJDhi01ZNhSQ4YtNWTYUkOGLTVk2FJD\nhi01ZNhSQ4YtNWTYUkOGLTVk2FJDhi01ZNhSQ4YtNWTYUkOGLTVk2FJDhi01ZNhSQ4YtNWTYUkNT\nw05yNMmPkpxL8usk79vEYJIWd3iGNVeAD1TVQ0luAH6R5PtVdWHNs0la0NQdu6ouV9VD4/t/Ay4A\nL1r3YJIWN9d77CQj4NXAz9YxjKTVmOVQHIDxYfgp4M7xzr3D6R33R+ObpNXbGt8A6rqrZgo7ydOB\nrwFfqqpv7F5xYr7ZJC1oxLWNs4Aze66a5ax4gHuB81X1qdUMJ2mdZnmP/TrgbcB/JTk7vp1c81yS\nljD1ULyqfoIXskgHisFKDRm21JBhSw0ZttSQYUsNGbbUkGFLDRm21JBhSw0ZttSQYUsNGbbUkGFL\nDRm21JBhSw0ZttSQYUsNGbbUkGFLDRm21JBhSw0ZttSQYUsNGbbUkGFLDRm21JBhSw0ZttSQYUsN\nGbbUkGFLDRm21JBhSw0ZttSQYUsNGbbU0ExhJzmU5GySb617IEnLm3XHvhM4D9QaZ5G0IlPDTvIS\n4DbgC0DWPpGkpc2yY38S+CDwzzXPImlFDk/6YJI3AH+oqrNJTlx/5ekd90fjm6TV2xrfYNI744lh\nA7cAb0xyG/AfwHOT3FdVb//3ZScWm1HSnEZc2zgLOLPnqomH4lX1kao6WlU3ArcDP9wdtaT9Zt6f\nY3tWXDoAph2K/7+q+jHw4zXOImlFvPJMasiwpYYMW2rIsKWGDFtqyLClhgxbasiwpYYMW2rIsKWG\nDFtqyLClhgxbasiwpYYMW2rIsKWGDFtqyLClhgxbasiwpYYMW2rIsKWGDFtqyLClhgxbasiwpYYM\nW2rIsKWGDFtqyLClhgxbasiwpYYMW2rIsKWGDFtqaGrYSZ6f5FSSC0nOJzm+icEkLe7wDGs+DXy7\nqt6c5DDw7DXPJGlJE8NO8jzg9VX1DoCqugo8sYnBJC1u2qH4jcAfk3wxyYNJPp/kyCYGk7S4aWEf\nBo4Bn62qY8DfgQ+vfSpJS5n2HvsicLGqfj5+fIo9wz694/5ofJO0elvjG0Bdd9XEsKvqcpLHkrys\nqn4D3Aqc273yxGIzSprTiGsbZwFn9lw1y1nx9wL3J3kG8DvgncsPJ2mdpoZdVQ8D/7mBWSStiFee\nSQ0ZttSQYUsNGbbUkGFLDRm21JBhSw
0ZttSQYUsNGbbUkGFLDRm21JBhSw0ZttSQYUsNGbbUkGFL\nDRm21JBhSw0ZttSQYUsNGbbUkGFLDRm21JBhSw0ZttSQYUsNGbbUkGFLDc3yZ3SlfeWj3LOWz3sP\nH13L5x2CO7bUkGFLDRm21JBhSw0ZttSQYUsNTQ07yV1JziX5VZIvJ3nmJgaTtLiJYScZAe8CjlXV\nK4FDwO3rH0vSMqZdoPIkcAU4kuQfwBHg8bVPJWkpE3fsqvoz8Ang98Al4C9V9YNNDCZpcRN37CQv\nBd4PjIAngP9N8taquv/fV57ecX80vklava3xDaCuu2raofhrgJ9W1Z8AknwduAV4StgnFplQ0txG\nXNs4Cziz56ppZ8UfAY4neVaSALcC51czoKR1mfYe+2HgPuAB4Jfjpz+37qEkLWfqf9usqo8DH9/A\nLJJWxCvPpIYMW2rIsKWGDFtqyLClhgxbaihV178sbaZPkBTcvaJx1MXdfGzoEeZWZC2fd12//fQm\nfsOjfIWq2jW4O7bUkGFLDRm21JBhSw0ZttSQYUsNGbbUkGFLDRm21JBhSw0ZttSQYUsNGbbUkGFL\nDRm21JBhSw0ZttSQYUsNDRD21uZfcilbQw+wgK2hB5jb1tADzGlrwl+63A8Me6qtoQdYwNbQA8xt\na+gB5rQ19ABTeCguNTT1j/LNZp7Dkppz/dAO2rywH2a+xAvmWv9X/solnjPT2hdyeZGRZjDvv9ls\n61/ApflHmcFzePK6H1vRrx+WNJS9fv3w0mFL2n98jy01ZNhSQxsLO8nJJI8keTTJhzb1uotKcjTJ\nj5KcS/LrJO8beqZZJDmU5GySbw09yyySPD/JqSQXkpxPcnzomaZJctf46+JXSb6c5JlDz/RUGwk7\nySHgM8BJ4OXAHUlu3sRrL+EK8IGqegVwHHj3AZgZ4E7gPEOfFp/dp4FvV9XNwKuACwPPM1GSEfAu\n4FhVvRI4BNw+5Ex72dSO/Vrgt1W1VVVXgK8Cb9rQay+kqi5X1UPj+39j+wvuRcNONVmSlwC3AV+A\nNf2FuRVK8jzg9VX1PwBVdbWqnhh4rGmeZPub/pEkh4EjwOPDjrTbpsJ+MfDYjscXx88dCOPv0q8G\nfjbsJFN9Evgg8M+hB5nRjcAfk3wxyYNJPp/kyNBDTVJVfwY+AfweuAT8pap+MOxUu20q7INyWLhL\nkhuAU8Cd4517X0ryBuAPVXWWA7Bbjx0GjgGfrapjwN+BDw870mRJXgq8HxixfQR3Q5K3DjrUHjYV\n9uPA0R2Pj7K9a+9rSZ4OfA34UlV9Y+h5prgFeGOS/wO+Avx3kvsGnmmai8DFqvr5+PEptkPfz14D\n/LSq/lRVV4Gvs/1vv69sKuwHgJuSjJI8A3gL8M0NvfZCkgS4FzhfVZ8aep5pquojVXW0qm5k+2TO\nD6vq7UPPNUlVXQYeS/Ky8VO3AucGHGkWjwDHkzxr/DVyK9snK/eVFV0rPllVXU3yHuB7bJ9FvLeq\n9vXZT+B1wNuAXyY5O37urqr67oAzzeOgvP15L3D/+Bv+74B3DjzPRFX18PhI6AG2z2U8CHxu2Kl2\n85JSqSGvPJMaMmypIcOWGjJsqSHDlhoybKkhw5YaMmypoX8BVnn6vPUZM7QAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x10c1b81d0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def save_img():\n",
    "    \"\"\"\n",
    "    Coroutine that writes every frame sent to it to a numbered PNG file.\n",
    "\n",
    "    Prime it with `next`, then `send` a state whose first element is the 2D\n",
    "    grid to draw. Frames are saved as images/000.png, images/001.png, ...\n",
    "    in the order in which they are received.\n",
    "    \"\"\"\n",
    "    import os  # Local import so that this cell does not rely on edits elsewhere\n",
    "\n",
    "    # savefig does not create missing directories, so make sure the target\n",
    "    # directory exists before the first frame arrives.\n",
    "    if not os.path.isdir('images'):\n",
    "        os.makedirs('images')\n",
    "\n",
    "    frame = 0\n",
    "    while True:\n",
    "        screen = (yield)\n",
    "        # Start from an empty figure; otherwise every imshow call stacks one\n",
    "        # more image on the same axes and saving gets slower frame by frame.\n",
    "        plt.clf()\n",
    "        plt.imshow(screen[0], interpolation='none')\n",
    "        plt.savefig('images/%03i.png' % frame)\n",
    "        frame += 1\n",
    "    \n",
    "img_saver = save_img()\n",
    "next(img_saver)  # Prime the coroutine so that it is ready to receive frames\n",
    "\n",
    "# Let the trained model play ten games greedily and save every frame.\n",
    "for _ in range(10):\n",
    "    g = episode()\n",
    "    S, _ = next(g)\n",
    "    img_saver.send(S)\n",
    "    while True:\n",
    "        # Greedy action: index of the largest q-value, shifted to -1, 0 or 1.\n",
    "        act = np.argmax(model.predict(S[np.newaxis]), axis=-1)[0] - 1\n",
    "        try:\n",
    "            S, _ = g.send(act)\n",
    "        except StopIteration:\n",
    "            # The previous frame was the final one of this game.\n",
    "            g.close()\n",
    "            break\n",
    "        img_saver.send(S)\n",
    "\n",
    "img_saver.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
